在了解Estimator API的部分,我們將會學到:
Estimator API是TensorFlow API中最高的階級(參考Day11中TensorFlow API 的階層),使用Estimator API有下列優點:
下面我們用例子來說明Estimator API的用法。
如果今天要建構一個線性迴歸的模型,除了從頭開始撰寫之外,使用 tf.estimator.LinearRegressor()
可以更快速的建構出來,如下面程式碼:
宣告好之後就是訓練和預測:
以預測房地產價格為例,完整個程式碼可寫成:
這樣短短幾行,我們就完成了線性迴歸模型的建置,當然若要建構更複雜如DNN的模型也可以改成使用 tf.estimator.DNNRegressor()
,如下圖:
使用checkpoint大致上可分為三種情況:
checkpoint使用方式如下圖,在宣告estimator內加入要存放checkpoint的路徑便可。
透過這樣的方式,在訓練的時候就會把checkpoint存下來,若該路徑本來就有checkpoint的話,就會該checkpoint繼續開始訓練。
這邊以一般常用的numpy和pandas資料類型舉例,只要使用 tf.estimator.inputs.numpy_input_fn()
或 tf.estimator.inputs.pandas_input_fn()
將資料x,y定義好,訓練的一些參數如批次大小、epoch數、要不要隨機洗亂資料...等等,就可以供後面模型訓練的輸入做使用。
全部模型訓練的程式大概如下面所示, model.train(train_input_fn(XXX))
這裡就是我們給入的記憶體內資料:
接著我們用實作來更詳細的看看上面提到的部分。
在這個實作中,我們將學會:
登入GCP,開啟Notebooks後,複製課程 Github repo (如Day9的Part 1 & 2步驟)。
在左邊的資料夾結構,點進 training-data-analyst > courses > machine_learning > deepdive > 03_tensorflow,然後打開檔案 b_estimator.ipynb。
首先先將資料import進來,這邊的資料只有一部分之前用到的紐約計程車資料(7700筆)。
shuffle = True
,但是 eval 就不需要所以是 shuffle = False
:y = None
:在這個Lab我們雖然訓練了一個線性迴歸模型,但表現卻比不上用簡單的經驗和直覺計算出來的結果,別擔心,後面的章節和實作將會介紹到怎麼讓ML模型的表現更好!
今天介紹了Estimator API,明天我們將介紹到 “如何在巨大的資料集上做訓練”。